{
    "cells": [
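        {
            "cell_type": "markdown",
            "id": "e45f9b60-cd6b-4c15-958f-1feca5438128",
            "metadata": {},
            "source": [
                "# SQL Index Demo\n",
                "\n",
                "This demo builds a SQL index in which the table is accompanied by extra context text, either written by hand or extracted by GPT, and then queries that index in natural language."
            ]
        },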
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "fbd7317b",
            "metadata": {},
            "outputs": [],
            "source": [
                "import logging\n",
                "import sys\n",
                "\n",
                "# Route INFO-level logs to stdout so they show up under the cells.\n",
                "# force=True replaces any handlers already attached to the root logger, so exactly\n",
                "# one stdout handler is installed and each record is printed once (calling\n",
                "# basicConfig and then adding a second StreamHandler prints every record twice).\n",
                "logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "107396a9-4aa7-49b3-9f0f-a755726c19ba",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index import GPTSQLStructStoreIndex, SQLDatabase, SimpleDirectoryReader, WikipediaReader, Document\n",
                "from llama_index.indices.struct_store import SQLContextContainerBuilder\n",
                "from IPython.display import Markdown, display"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "77ac8d94-cd61-4869-a32b-0b2e7d18b83f",
            "metadata": {},
            "source": [
                "### Load Wikipedia Data"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "93301fcf-a52b-430c-98a3-5360e6c8fc4a",
            "metadata": {},
            "outputs": [],
            "source": [
                "# install the wikipedia python package\n",
                "# (%pip targets the environment of the running kernel; !pip may hit a different one)\n",
                "%pip install wikipedia"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "ba3f7e5e-cdc4-4529-bba9-db45d8457dba",
            "metadata": {},
            "outputs": [],
            "source": [
                "# load one Wikipedia page per city\n",
                "wiki_docs = WikipediaReader().load_data(\n",
                "    pages=[\"Toronto\", \"Berlin\", \"Tokyo\"],\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "461438c8-302d-45c5-8e69-16ad604686d1",
            "metadata": {},
            "source": [
                "### Create Database Schema"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
            "metadata": {},
            "outputs": [],
            "source": [
                "from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "ea24f794-f10b-42e6-922d-9258b7167405",
            "metadata": {},
            "outputs": [],
            "source": [
                "# in-memory SQLite database\n",
                "engine = create_engine(\"sqlite:///:memory:\")\n",
                "# MetaData is left unbound: the `bind` argument is deprecated in SQLAlchemy 1.4 and\n",
                "# removed in 2.0, so the engine is passed explicitly wherever it is needed\n",
                "metadata_obj = MetaData()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "b4154b29-7e23-4c26-a507-370a66186ae7",
            "metadata": {},
            "outputs": [],
            "source": [
                "# create city SQL table\n",
                "table_name = \"city_stats\"\n",
                "city_stats_table = Table(\n",
                "    table_name,\n",
                "    metadata_obj,\n",
                "    Column(\"city_name\", String(16), primary_key=True),\n",
                "    Column(\"population\", Integer),\n",
                "    Column(\"country\", String(16), nullable=False),\n",
                ")\n",
                "# create the table on the engine explicitly (does not rely on bound metadata)\n",
                "metadata_obj.create_all(engine)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "1c09089a-6bcd-48db-8120-a84c8da3f82e",
            "metadata": {
                "tags": []
            },
            "source": [
                "### Build Index with Context"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "611319e5-d3c2-4286-a84f-ed2459896c58",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index import GPTSQLStructStoreIndex, SQLDatabase\n",
                "from llama_index.indices.struct_store import SQLContextContainerBuilder"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "3fc2dfab-90ea-4f01-9e28-d21fdc5f0758",
            "metadata": {},
            "outputs": [],
            "source": [
                "# wrap the engine, restricted to the city_stats table\n",
                "sql_database = SQLDatabase(\n",
                "    engine,\n",
                "    include_tables=[\"city_stats\"],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "89f6f1d1-a022-43d7-b135-a79ec9407956",
            "metadata": {},
            "outputs": [],
            "source": [
                "sql_database.table_info"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "331ff0ce-9131-4680-a5f2-3f41c73e018e",
            "metadata": {},
            "source": [
                "We can either set the table context manually, or have GPT extract the context for us from a document. Both options are shown below."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "a743f365-21c6-4eae-a2f4-fc72d4199daa",
            "metadata": {},
            "outputs": [],
            "source": [
                "# manually set context text\n",
                "city_stats_text = (\n",
                "    \"This table gives information regarding the population and country of a given city.\\n\"\n",
                "    # keep the trailing space: adjacent string literals are concatenated as-is\n",
                "    \"The user will query with codewords, where 'foo' corresponds to population and 'bar' \"\n",
                "    \"corresponds to city.\"\n",
                ")\n",
                "table_context_dict = {\"city_stats\": city_stats_text}\n",
                "context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)\n",
                "context_container = context_builder.build_context_container()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "b15ffc74-4a44-40b4-87b1-44f952ebfd58",
            "metadata": {},
            "outputs": [],
            "source": [
                "# build the index over the Wikipedia documents, passing in the context container\n",
                "index = GPTSQLStructStoreIndex.from_documents(\n",
                "    wiki_docs,\n",
                "    sql_database=sql_database,\n",
                "    table_name=\"city_stats\",\n",
                "    sql_context_container=context_container,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "5cdcc666-4b51-4ed5-92ca-e98e8a01fdd0",
            "metadata": {},
            "outputs": [],
            "source": [
                "# extract context from a raw Document using GPT\n",
                "# note: this cell rebinds `city_stats_text`, `context_container` and `index`\n",
                "city_stats_text = \"This table gives information regarding the population and country of a given city.\\n\"\n",
                "context_documents_dict = {\"city_stats\": [Document(city_stats_text)]}\n",
                "context_builder = SQLContextContainerBuilder.from_documents(context_documents_dict, sql_database)\n",
                "context_container = context_builder.build_context_container()\n",
                "\n",
                "index = GPTSQLStructStoreIndex.from_documents(\n",
                "    wiki_docs,\n",
                "    sql_database=sql_database,\n",
                "    table_name=\"city_stats\",\n",
                "    sql_context_container=context_container,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "b315b8ff-7dd7-4e7d-ac47-8c5a0c3e7ae9",
            "metadata": {},
            "outputs": [],
            "source": [
                "# view current table\n",
                "# Table.select() selects every column of city_stats (city_name, population, country);\n",
                "# the legacy select([...]) list form is not accepted by SQLAlchemy 2.0\n",
                "stmt = city_stats_table.select()\n",
                "\n",
                "with engine.connect() as connection:\n",
                "    results = connection.execute(stmt).fetchall()\n",
                "    print(results)\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "051a171f-8c97-40ed-ae17-4e3fa3785487",
            "metadata": {},
            "source": [
                "### Query Index"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "91139712-f232-47e1-9683-cbbd49cd331b",
            "metadata": {},
            "source": [
                "Here we show a natural language query, which is translated to a SQL query under the hood."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "a76d1816-4f70-4914-80af-7b968c614592",
            "metadata": {},
            "outputs": [],
            "source": [
                "# set Logging to DEBUG for more detailed outputs\n",
                "response = index.query(\"Which city has the highest population?\", mode=\"default\")\n",
                "# show the answer here: `response` is reassigned by the next query cell\n",
                "display(Markdown(f\"<b>{response}</b>\"))"
            ]
        },
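        {
            "cell_type": "markdown",
            "id": "5dc2f7bf-6f6c-42ba-8f42-47afea6606ad",
            "metadata": {},
            "source": [
                "We can also use codewords in the natural language query! The codewords used below ('foo' for population, 'bar' for city) are the ones defined in the manually set context text above."
            ]
        },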
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "d71045c0-7a96-4e86-b38c-c378b7759aa4",
            "metadata": {},
            "outputs": [],
            "source": [
                "# set Logging to DEBUG for more detailed outputs\n",
                "response = index.query(\n",
                "    \"Which bar has the highest foo?\",\n",
                "    mode=\"default\",\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "e713d73e-73ed-4748-8673-f476899fac8e",
            "metadata": {},
            "outputs": [],
            "source": [
                "display(Markdown(f\"<b>{response}</b>\"))"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3 (ipykernel)",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.16"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}